--- ../../swvox/swvox/csrc/include/common.cuh	2024-02-14 05:16:57.537407540 -0500
+++ ../google/swvox/swvox/csrc/include/common.cuh	2024-02-13 10:46:24.588645001 -0500
@@ -30,6 +30,14 @@
 #include <cuda.h>
 #include <cuda_runtime.h>
 
+#define MAX_TREE_DEPTH 30
+#define N_WAVELET_EVALUATIONS 30
+// For HAAR, there are 7 wavelets (LLH, LHL, ...).
+// For the trilinear kernel, there are 8 wavelets.
+#define MAX_WAVELET_SIZE 8
+// Sentinel that query_path_from_root writes into node_id_out after the leaf entry.
+#define INVALID_NODE_ID -1
+
 namespace {
 namespace device {
 
@@ -75,6 +83,8 @@
         xyz_inout[1] -= v;
         xyz_inout[2] -= w;
 
+        // printf("querying at : %p\n", (void*)&data[node_id][u][v][w][0]);
+
         const int32_t skip = child[node_id][u][v][w];
         if (skip == 0) {
             if (node_id_out != nullptr) {
@@ -89,6 +99,78 @@
     return nullptr;
 }
 
+// Descends from the root to the leaf cell containing xyz_inout (clamped first),
+// recording each level. Per level, xyz_inout is scaled by N = child.size(1) and
+// split into the integer cell (u, v, w) and a fractional remainder that stays in
+// xyz_inout. Outputs are indexed by depth (0 is the root):
+//   all_relative_pos[3 * depth + k]  fractional position, k = 0..2
+//   cube_sz_out[depth]               N^(depth + 1)
+//   traversal_nodes[depth]           &data[node_id][u][v][w][0]
+//   node_id_out[depth] (optional)    flattened (node_id, u, v, w) index, then
+//                                    INVALID_NODE_ID after the leaf if it fits
+// Returns the leaf data pointer; nullptr if tree_max_depth > MAX_TREE_DEPTH or
+// the tree is deeper than tree_max_depth (buffers are assumed to hold
+// tree_max_depth + 1 levels -- TODO confirm against callers).
+template <typename scalar_t>
+__device__ __inline__ scalar_t* query_path_from_root(
+    torch::PackedTensorAccessor64<scalar_t, 5, torch::RestrictPtrTraits> data,
+    const torch::PackedTensorAccessor32<int32_t, 4, torch::RestrictPtrTraits> child,
+    scalar_t* __restrict__ xyz_inout,
+    scalar_t* __restrict__ all_relative_pos,
+    scalar_t* __restrict__ cube_sz_out,
+    scalar_t** traversal_nodes,  // receives the path, one pointer per level
+    const int64_t tree_max_depth,
+    // was __restrict__ in original leaf query function
+    int64_t* node_id_out=nullptr) {
+    if (tree_max_depth > MAX_TREE_DEPTH) {
+        printf("tree_max_depth %d > MAX_TREE_DEPTH. This will cause issues! You need to "
+               "recompile the cuda extension with bigger MAX_TREE_DEPTH\n", (int) tree_max_depth);
+        return nullptr;
+    }
+
+    const int64_t n = child.size(1);  // cells per axis within one node
+    const scalar_t N = child.size(1);
+    clamp_coord<scalar_t>(xyz_inout);
+    int32_t node_id = 0;
+    int32_t depth = 0;  // depth index, 0 corresponds to the root
+    cube_sz_out[0] = N;
+    while (true) {
+        // Scale into this level's grid, then split cell index from remainder.
+        xyz_inout[0] *= N;
+        xyz_inout[1] *= N;
+        xyz_inout[2] *= N;
+        const int32_t u = floor(xyz_inout[0]);
+        const int32_t v = floor(xyz_inout[1]);
+        const int32_t w = floor(xyz_inout[2]);
+        xyz_inout[0] -= u;
+        xyz_inout[1] -= v;
+        xyz_inout[2] -= w;
+        all_relative_pos[depth * 3 + 0] = xyz_inout[0];
+        all_relative_pos[depth * 3 + 1] = xyz_inout[1];
+        all_relative_pos[depth * 3 + 2] = xyz_inout[2];
+
+        const int32_t skip = child[node_id][u][v][w];
+        scalar_t* const cell = &data[node_id][u][v][w][0];
+        traversal_nodes[depth] = cell;
+        if (node_id_out != nullptr) {
+            node_id_out[depth] = ((node_id * n + u) * n + v) * n + w;
+        }
+        if (skip == 0) {  // leaf: nothing below this cell
+            if (node_id_out != nullptr && depth < tree_max_depth) {
+                node_id_out[depth + 1] = INVALID_NODE_ID;
+            }
+            return cell;
+        }
+        // Still interior at the deepest allowed level: stop before overrunning buffers.
+        if (depth >= tree_max_depth) break;
+        depth += 1;
+        cube_sz_out[depth] = N * cube_sz_out[depth - 1];
+        node_id += skip;
+    }
+    return nullptr;
+}
+
+
 }  // namespace device
 }  // namespace
 
@@ -98,7 +180,7 @@
 #define CUDA_CHECK_ERRORS \
     cudaError_t err = cudaGetLastError(); \
     if (err != cudaSuccess) \
-            printf("Error in svox.%s : %s\n", __FUNCTION__, cudaGetErrorString(err))
+            printf("Error in swvox.%s : %s\n", __FUNCTION__, cudaGetErrorString(err))
 
 namespace {
 // Get approx number of CUDA cores
